import pandas as p
import numpy as np
from tqdm import tqdm, trange
def loadParameters(travel=True):
    """Read the treatment-effect distribution parameters from 'Tables/iv.xlsx'.

    The first two columns of the chosen sheet are read as (parameter name,
    value) pairs with no header row; rows whose value cell is empty are
    discarded.

    Args:
        travel: when true, read sheet 'Spec2'; otherwise read 'Spec2GasUse'.

    Returns:
        A dict with
          'mean': [pi, 0]
          'cov':  [[pi variance, pi-delta covariance low],
                   [pi-delta covariance low, delta variance]]
        suitable for ``np.random.multivariate_normal``.

    Raises:
        KeyError: if one of the required parameter names is absent from the sheet.
    """
    sheet = 'Spec2' if travel else 'Spec2GasUse'
    table = p.read_excel('Tables/iv.xlsx', sheet, names=['param', 'values'], usecols=[0, 1], header=None)

    # Blank value cells come back as missing values; drop those rows.
    table = table[table['values'].notna()]
    lookup = {name: value for name, value in zip(table['param'], table['values'])}

    # Keys are looked up in the same order as before so a missing key reports identically.
    meanVec = [lookup['pi'], 0]
    piVar = lookup['pi variance']
    piDeltaCov = lookup['pi-delta covariance low']
    deltaVar = lookup['delta variance']

    return {
        'mean': meanVec,
        'cov': [[piVar, piDeltaCov], [piDeltaCov, deltaVar]],
    }

def simulateStateFixedEffects(nStates=51):
    """Draw state-level fixed effects for the fleet and use equations.

    Each state gets one independent draw from a bivariate normal with mean
    (20, 19) and covariance diag(2, 2), i.e. the two effects are uncorrelated.

    Args:
        nStates: number of states; states are labelled 0 .. nStates-1.

    Returns:
        DataFrame with columns 'StateFleet', 'StateUse', 'State', one row per state.
    """
    draws = np.random.multivariate_normal([20, 19], [[2, 0], [0, 2]], size=nStates)
    frame = p.DataFrame({'StateFleet': draws[:, 0], 'StateUse': draws[:, 1]})
    frame['State'] = range(nStates)
    return frame

def simulateTimeFixedEffects(start=1970, end=2017):
    """Draw year-level fixed effects for the fleet and use equations.

    Every year in the inclusive range [start, end] gets an independent draw
    from a bivariate standard normal (mean 0, identity covariance).

    Returns:
        DataFrame with columns 'TimeFleet', 'TimeUse', 'Time', one row per
        year in ascending order.
    """
    years = range(start, end + 1)
    draws = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], size=len(years))
    frame = p.DataFrame({'TimeFleet': draws[:, 0], 'TimeUse': draws[:, 1]})
    frame['Time'] = years
    return frame

def simulateStateTreatmentEffects(nStates=51, travel=True):
    """Draw per-state treatment effects (pi, delta).

    The joint distribution is the bivariate normal whose mean and covariance
    are returned by ``loadParameters(travel)``.

    Args:
        nStates: number of states; rows are labelled State 0 .. nStates-1.
        travel: forwarded to ``loadParameters`` to pick the parameter sheet.

    Returns:
        DataFrame with columns 'pi', 'delta', 'State', one row per state.
    """
    params = loadParameters(travel)
    draws = np.random.multivariate_normal(params['mean'], params['cov'], size=nStates)

    out = p.DataFrame(draws, columns=['pi', 'delta'])
    # The default RangeIndex doubles as the state label.
    out['State'] = out.index
    return out

def simulateStateTreatmentYears(nStates=51, start=1970, end=2017):
    """Draw each state's treatment year and treatment type.

    Random draws, in this order (the order matters for seeded reproducibility):
      1. a candidate treatment year per state, uniform on [start, end - 1]
         (``randint``'s upper bound is exclusive);
      2. a treatment type per state: 'Always' w.p. 0.5, 'Never' w.p. 0.25,
         'Treated' w.p. 0.25.

    'Never' states get StateTreatTime = 9999999999 and 'Always' states get 0.
    These sentinels are chosen so that a ``Time > StateTreatTime`` comparison is
    never / always true, respectively, for realistic years.

    Returns:
        DataFrame with columns 'StateTreatTime' (int64), 'State', 'TreatType'.
    """
    # Cast to int64 up front: randint defaults to int32 on some platforms, which
    # cannot hold the 'Never' sentinel below.
    treatYears = np.random.randint(start, end, size=nStates).astype(np.int64)
    # One categorical draw per state; argmax of a one-hot row gives the category index.
    typeCodes = np.argmax(np.random.multinomial(1, [.5, .25, .25], size=nStates), axis=1)
    labels = ['Always', 'Never', 'Treated']

    r = p.DataFrame({'StateTreatTime': treatYears, 'State': range(nStates)})
    # Assign the string column directly rather than overwriting an int column in
    # place via .loc, which pandas deprecates (incompatible-dtype setitem).
    r['TreatType'] = [labels[code] for code in typeCodes]

    r.loc[r['TreatType'] == 'Never', 'StateTreatTime'] = 9999999999
    r.loc[r['TreatType'] == 'Always', 'StateTreatTime'] = 0

    return r
def constructBaseDF(nStates=51, start=1970, end=2017):
    """Build the balanced state-by-year skeleton of the panel.

    Returns:
        DataFrame with columns 'State' and 'Time': one row for every state in
        0 .. nStates-1 and every year in [start, end], ordered state-major
        (all years of state 0, then all years of state 1, ...).
    """
    years = range(start, end + 1)
    grid = [[state, year] for state in range(nStates) for year in years]
    return p.DataFrame(grid, columns=['State', 'Time'])

def simulateDataset(nStates=51, start=1970, end=2017, travel=True):
    """Simulate one state-by-year panel with staggered treatment.

    Outcome equations (per state s, year t):
        Fleet = FleetError + StateFleet_s + TimeFleet_t + pi_s    * TreatInd
        Use   = UseError   + StateUse_s   + TimeUse_t   + delta_s * TreatInd
    where TreatInd = 1 when Time > StateTreatTime, else 0.

    Args:
        nStates: number of states.
        start, end: inclusive year range.
        travel: forwarded to the treatment-effect parameter loader.

    Returns:
        DataFrame sorted by ('State', 'Time') containing the drawn effects,
        errors, treatment indicator and the two outcomes 'Fleet' and 'Use'.
    """
    baseDF = constructBaseDF(nStates, start, end)
    # The calls below consume the global RNG in this fixed order; keep it
    # stable so seeded runs stay reproducible.
    stateFEs = simulateStateFixedEffects(nStates)
    stateTTs = simulateStateTreatmentYears(nStates, start, end)
    stateTEs = simulateStateTreatmentEffects(nStates, travel)
    timeFEs = simulateTimeFixedEffects(start, end)

    baseDF = (baseDF.merge(stateFEs, on='State')
                    .merge(stateTTs, on='State')
                    .merge(stateTEs, on='State')
                    .merge(timeFEs, on='Time'))

    nRows = len(baseDF)
    baseDF['FleetError'] = np.random.normal(size=nRows)
    baseDF['UseError'] = np.random.normal(size=nRows)

    # Strict '>' means treatment begins the year after StateTreatTime. Build the
    # int column in one step instead of overwriting a bool column in place via
    # .loc, which pandas deprecates (incompatible-dtype setitem).
    baseDF['TreatInd'] = (baseDF['Time'] > baseDF['StateTreatTime']).astype(int)
    baseDF = baseDF.sort_values(['State', 'Time'])

    baseDF['Fleet'] = (baseDF['FleetError'] + baseDF['StateFleet'] + baseDF['TimeFleet']
                       + (baseDF['pi'] * baseDF['TreatInd']))
    baseDF['Use'] = (baseDF['UseError'] + baseDF['StateUse'] + baseDF['TimeUse']
                     + (baseDF['delta'] * baseDF['TreatInd']))

    return baseDF


# Generate 100 'Travel' and 100 'Gas' simulated panels, one worksheet each,
# into Data/SimData.xlsx. The fixed seed makes the whole workbook reproducible.
np.random.seed(1995)
# The context manager guarantees the writer is closed (and its file handle
# released) even if a simulation raises part-way through.
with p.ExcelWriter('Data/SimData.xlsx') as xlw:
    for travel in [True, False]:
        prefix = 'Travel' if travel else 'Gas'
        for i in trange(1, 101):
            tempDataset = simulateDataset(travel=travel)
            # Keyword sheet_name: positional args after the writer are deprecated in pandas 2.1+.
            tempDataset.to_excel(xlw, sheet_name='{}-{}'.format(prefix, i), index=False)